# Copyright 2025 The KServe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import os
import pytest
import requests
import yaml
from dataclasses import dataclass, field
from kserve import KServeClient, V1alpha1LLMInferenceService, constants
from kubernetes import client
from typing import Any, Callable, List, Optional

from .diagnostic import (
    print_all_events_table,
    kinds_matching_by_labels,
)
from .fixtures import (  # noqa: F401
    create_router_resources,
    delete_router_resources,
    generate_test_id,
    inject_k8s_proxy,
    # Factory functions are not called explicitly, but they need to be imported to work
    test_case,  # noqa: F811
)
from .test_resources import (
    ROUTER_GATEWAYS,
    ROUTER_ROUTES,
)
from .logging import log_execution

# Plural resource name passed to the Kubernetes custom-objects API calls in
# this module when creating, fetching, or deleting an LLMInferenceService.
KSERVE_PLURAL_LLMINFERENCESERVICE = "llminferenceservices"


def assert_200(response: requests.Response) -> None:
    """Fail with an AssertionError unless the response status code is 200.

    This is the default ``response_assertion`` of ``TestCase``. The failure
    message carries the received status code and the response body text.
    """
    status = response.status_code
    assert status == 200, f"Service returned {status}: {response.text}"


def assert_200_with_choices(response: requests.Response) -> None:
    """Assert a 200 status code and a non-empty ``choices`` list in the JSON body.

    The body is parsed at most once, and only when the status code is 200.
    A body that is not valid JSON, is not a JSON object, or has a missing or
    empty ``choices`` list is reported as an AssertionError, so that callers
    which retry on AssertionError (such as ``wait_for``) treat a malformed
    reply the same way as any other failed check.

    Raises:
        AssertionError: if the status code is not 200 or no choices are present.
    """
    choices = None
    if response.status_code == 200:
        try:
            body = response.json()
        except ValueError:
            # requests reports undecodable JSON with a ValueError subclass.
            body = None
        if isinstance(body, dict):
            choices = body.get("choices")

    assert (
        isinstance(choices, list) and len(choices) > 0
    ), f"Expected 200 with choices, got {response.status_code}: {response.text}"


@dataclass
class TestCase:
    """Test case configuration for LLM inference service tests.

    One instance describes a single parametrized run of the end-to-end test:
    what to send to the model and how long to wait for it.
    """

    # The docstring must be the first statement in the class body; placed
    # after this assignment it would be a discarded string expression.
    __test__ = False  # So pytest will not try to execute it.

    # Names of the configuration pieces the service is assembled from
    # (presumably resolved by the test_case fixture -- confirm in fixtures).
    base_refs: List[str]
    # Sent as the "prompt" field of the completion request.
    prompt: str
    # Optional explicit service name (presumably consumed by the fixture).
    service_name: Optional[str] = None
    # Sent as "max_tokens"; also used to truncate the returned response text.
    max_tokens: int = 100
    # Called with the HTTP response; expected to raise AssertionError on failure.
    response_assertion: Callable[[requests.Response], None] = assert_200
    # Seconds to wait for the service to become ready and for the model to respond.
    wait_timeout: int = 900
    # Per-request timeout, in seconds, of the completion HTTP call.
    response_timeout: int = 60
    # Hooks around the test (presumably run by the fixture -- confirm in fixtures).
    before_test: List[Callable[[], Any]] = field(default_factory=list)
    after_test: List[Callable[[], Any]] = field(default_factory=list)
    # Factory provided
    llm_service: Optional[V1alpha1LLMInferenceService] = (
        None  # Generated by llm_service_factory
    )
    model_name: str = "default/model"  # This will be generated by the factory


@pytest.mark.llminferenceservice
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize(
    "test_case",
    [
        pytest.param(
            TestCase(
                base_refs=[
                    "router-managed",
                    "workload-single-cpu",
                    "model-fb-opt-125m",
                ],
                prompt="KServe is a",
            ),
            marks=[pytest.mark.cluster_cpu, pytest.mark.cluster_single_node],
        ),
        pytest.param(
            TestCase(
                base_refs=["router-custom-route-timeout", "scheduler-managed",
                           "workload-single-cpu", "model-fb-opt-125m"],
                prompt="KServe is a",
                service_name="custom-route-timeout-test",
            ),
            marks=[pytest.mark.cluster_cpu, pytest.mark.cluster_single_node],
        ),
        pytest.param(
            TestCase(
                base_refs=["router-with-refs", "scheduler-managed", "workload-single-cpu", "model-fb-opt-125m"],
                prompt="KServe is a",
                service_name="router-with-refs-test",
                before_test=[lambda: create_router_resources(
                    gateways=[ROUTER_GATEWAYS[0]],
                    routes=[ROUTER_ROUTES[0], ROUTER_ROUTES[1]],
                )],
                after_test=[lambda: delete_router_resources(
                    gateways=[ROUTER_GATEWAYS[0]],
                    routes=[ROUTER_ROUTES[0], ROUTER_ROUTES[1]],
                )],
            ),
            marks=[pytest.mark.cluster_cpu, pytest.mark.cluster_single_node],
        ),
        pytest.param(
            TestCase(
                base_refs=["router-managed", "workload-pd-cpu", "model-fb-opt-125m"],
                prompt="You are an expert in Kubernetes-native machine learning serving platforms, with deep knowledge of the KServe project. "
                "Explain the challenges of serving large-scale models, GPU scheduling, and how KServe integrates with capabilities like multi-model serving. "
                "Provide a detailed comparison with open source alternatives, focusing on operational trade-offs.",
                response_assertion=assert_200_with_choices,
            ),
            marks=[pytest.mark.cluster_cpu, pytest.mark.cluster_single_node],
        ),
        pytest.param(
            TestCase(
                base_refs=["router-custom-route-timeout-pd",
                           "scheduler-managed", "workload-pd-cpu", "model-fb-opt-125m"],
                prompt="You are an expert in Kubernetes-native machine learning serving platforms, with deep knowledge of the KServe project. "
                "Explain the challenges of serving large-scale models, GPU scheduling, and how KServe integrates with capabilities like multi-model serving. "
                "Provide a detailed comparison with open source alternatives, focusing on operational trade-offs.",
                service_name="custom-route-timeout-pd-test",
                response_assertion=assert_200_with_choices,
            ),
            marks=[pytest.mark.cluster_cpu, pytest.mark.cluster_single_node],
        ),
        pytest.param(
            TestCase(
                base_refs=["router-with-refs-pd", "scheduler-managed", "workload-pd-cpu", "model-fb-opt-125m"],
                prompt="You are an expert in Kubernetes-native machine learning serving platforms, with deep knowledge of the KServe project. "
                "Explain the challenges of serving large-scale models, GPU scheduling, and how KServe integrates with capabilities like multi-model serving. "
                "Provide a detailed comparison with open source alternatives, focusing on operational trade-offs.",
                service_name="router-with-refs-pd-test",
                response_assertion=assert_200_with_choices,
                before_test=[lambda: create_router_resources(
                    gateways=[ROUTER_GATEWAYS[1]],
                    routes=[ROUTER_ROUTES[2], ROUTER_ROUTES[3]],
                )],
                after_test=[lambda: delete_router_resources(
                    gateways=[ROUTER_GATEWAYS[1]],
                    routes=[ROUTER_ROUTES[2], ROUTER_ROUTES[3]],
                )],
            ),
            marks=[pytest.mark.cluster_cpu, pytest.mark.cluster_single_node],
        ),
        pytest.param(
            TestCase(
                base_refs=[
                    "router-managed",
                    "workload-dp-ep-gpu",
                    "workload-dp-ep-prefill-gpu",
                    "model-deepseek-v2-lite",
                ],
                prompt="Delve into the multifaceted implications of a fully disaggregated cloud architecture, specifically "
                "where the compute plane (P) and the data plane (D) are independently deployed and managed for a "
                "geographically distributed, high-throughput, low-latency microservices ecosystem. Beyond the "
                "fundamental challenges of network latency and data consistency, elaborate on the advanced "
                "considerations and trade-offs inherent in such a setup: 1. Network Architecture and Protocols: "
                "How would the network fabric and underlying protocols (e.g., RDMA, custom transport layers) need to "
                "evolve to support optimal performance and minimize inter-plane communication overhead, especially for "
                "synchronous operations? Discuss the role of network programmability (e.g., SDN, P4) in dynamically "
                "optimizing routing and traffic flow between P and D. 2. Advanced Data Consistency and Durability: "
                "Explore sophisticated data consistency models (e.g., causal consistency, strong eventual consistency) "
                "and their applicability in balancing performance and data integrity across a globally distributed data plane. "
                "Detail strategies for ensuring data durability and fault tolerance, including multi-region replication, "
                "intelligent partitioning, and recovery mechanisms in the event of partial or full plane failures. "
                "3. Dynamic Resource Orchestration and Cost Optimization: Analyze how an orchestration layer would intelligently "
                "manage the independent scaling of compute (P) and data (D) resources, considering fluctuating workloads, "
                "cost efficiency, and performance targets (e.g., using predictive analytics for resource provisioning). "
                "Discuss mechanisms for dynamically reallocating compute nodes to different data partitions based on "
                "workload patterns and data locality, potentially involving live migration strategies. "
                "4. Security and Compliance in a Distributed Landscape: Address the enhanced security perimeter "
                "challenges, including securing communication channels between P and D (encryption in transit, mutual TLS), "
                "fine-grained access control to data at rest and in motion, and identity management across disaggregated "
                "components. Discuss how such an architecture impacts compliance with regulatory frameworks (e.g., GDPR, HIPAA) "
                "concerning data sovereignty, privacy, and auditability. 5. Operational Complexity and Observability: "
                "Examine the increased complexity in monitoring, logging, and tracing across highly decoupled compute and "
                "data planes. What specialized tooling and practices (e.g., distributed tracing with OpenTelemetry, advanced AIOps) "
                "would be essential? How would incident response and troubleshooting differ in this disaggregated environment "
                "compared to traditional integrated systems? Consider the challenges of pinpointing root causes across "
                "independent failures. 6. Real-world Applicability and Future Trends: Identify specific industries "
                "or use cases (e.g., high-frequency trading, IoT edge processing, large language model inference) "
                "where the benefits of P/D disaggregation would strongly outweigh its complexities. "
                "Conclude by speculating on emerging technologies or paradigms (e.g., serverless compute functions "
                "directly interacting with object storage, in-memory disaggregation) that could further drive or "
                "transform P/D disaggregation in cloud computing.",
                max_tokens=2000,
            ),
            marks=[
                pytest.mark.cluster_gpu,
                pytest.mark.cluster_nvidia,
                pytest.mark.cluster_nvidia_roce,
            ],
        ),
        pytest.param(
            TestCase(
                base_refs=[
                    "router-no-scheduler",
                    "workload-single-cpu",
                    "model-fb-opt-125m",
                ],
                prompt="What is KServe?",
            ),
            marks=[
                pytest.mark.cluster_cpu,
                pytest.mark.cluster_single_node,
                pytest.mark.no_scheduler,
            ],
        ),
        pytest.param(
            TestCase(
                base_refs=[
                    "router-managed",
                    "workload-simulated-dp-ep-cpu",
                    "model-fb-opt-125m",
                ],
                prompt="This test simulates DP+EP that can run on CPU, the idea is to test the LWS-based deployment, "
                "but without the resources requirements for DP+EP (GPUs and ROCe/IB).",
            ),
            marks=[pytest.mark.cluster_cpu, pytest.mark.cluster_multi_node],
        ),
    ],
    indirect=["test_case"],
    ids=generate_test_id,
)
@log_execution
def test_llm_inference_service(test_case: TestCase):  # noqa: F811
    """Deploy an LLMInferenceService, wait for readiness, and query the model.

    The service in ``test_case.llm_service`` is created, then the test waits
    for it to become ready and for a completion request to satisfy
    ``test_case.response_assertion``. On any failure, diagnostics are printed
    and the original exception is re-raised.

    The service is deleted afterwards when ``SKIP_RESOURCE_DELETION`` is unset
    or set to ``false``, ``0`` or ``f`` (case-insensitive); any other value
    leaves it in place. Cleanup errors are reported as warnings only.
    """
    inject_k8s_proxy()

    kserve_client = KServeClient(
        config_file=os.environ.get("KUBECONFIG", "~/.kube/config"),
        client_configuration=client.Configuration(),
    )

    service_name = test_case.llm_service.metadata.name

    try:
        create_llmisvc(kserve_client, test_case.llm_service)
        wait_for_llm_isvc_ready(
            kserve_client, test_case.llm_service, test_case.wait_timeout
        )
        wait_for_model_response(kserve_client, test_case, test_case.wait_timeout)
    except Exception as e:
        print(f"❌ ERROR: Failed to call llm inference service {service_name}: {e}")
        # Diagnostics are best-effort: if collecting them fails (for example
        # because the service was never created), report it as a warning so
        # the original failure below is the one pytest raises.
        try:
            _collect_diagnostics(kserve_client, test_case.llm_service)
        except Exception as diag_error:
            print(
                f"⚠️ Warning: Failed to collect diagnostics for {service_name}: {diag_error}"
            )
        raise
    finally:
        try:
            if os.getenv("SKIP_RESOURCE_DELETION", "False").lower() in (
                "false",
                "0",
                "f",
            ):
                delete_llmisvc(kserve_client, test_case.llm_service)
        except Exception as e:
            print(f"⚠️ Warning: Failed to cleanup service {service_name}: {e}")


@log_execution
def create_llmisvc(kserve_client: KServeClient, llm_isvc: V1alpha1LLMInferenceService):
    """Create the given LLMInferenceService as a namespaced custom object.

    The API version is the part of ``llm_isvc.api_version`` after the "/",
    and the namespace comes from the object's metadata.

    Returns:
        Whatever the custom-objects API call returns.

    Raises:
        RuntimeError: if the Kubernetes client raises an ApiException.
    """
    api = kserve_client.api_instance
    version = llm_isvc.api_version.split("/")[1]
    namespace = llm_isvc.metadata.namespace

    try:
        created = api.create_namespaced_custom_object(
            constants.KSERVE_GROUP,
            version,
            namespace,
            KSERVE_PLURAL_LLMINFERENCESERVICE,
            llm_isvc,
        )
    except client.rest.ApiException as e:
        raise RuntimeError(
            f"❌ Exception when calling CustomObjectsApi->"
            f"create_namespaced_custom_object for LLMInferenceService: {e}"
        ) from e

    print(f"✅ LLM inference service {llm_isvc.metadata.name} created successfully")
    return created


@log_execution
def delete_llmisvc(kserve_client: KServeClient, llm_isvc: V1alpha1LLMInferenceService):
    """Delete the given LLMInferenceService custom object by name.

    The API version is the part of ``llm_isvc.api_version`` after the "/";
    name and namespace come from the object's metadata.

    Returns:
        Whatever the custom-objects API call returns.

    Raises:
        RuntimeError: if the Kubernetes client raises an ApiException.
    """
    api = kserve_client.api_instance
    version = llm_isvc.api_version.split("/")[1]
    meta = llm_isvc.metadata

    try:
        deleted = api.delete_namespaced_custom_object(
            constants.KSERVE_GROUP,
            version,
            meta.namespace,
            KSERVE_PLURAL_LLMINFERENCESERVICE,
            meta.name,
        )
    except client.rest.ApiException as e:
        raise RuntimeError(
            f"❌ Exception when calling CustomObjectsApi->"
            f"delete_namespaced_custom_object for LLMInferenceService: {e}"
        ) from e

    print(f"✅ LLM inference service {llm_isvc.metadata.name} deleted successfully")
    return deleted


@log_execution
def get_llmisvc(
    kserve_client: KServeClient,
    name,
    namespace,
    version=constants.KSERVE_V1ALPHA1_VERSION,
):
    """Fetch an LLMInferenceService custom object by name and namespace.

    Returns:
        Whatever the custom-objects API call returns (callers in this module
        index it like a dict).

    Raises:
        RuntimeError: if the Kubernetes client raises an ApiException.
    """
    api = kserve_client.api_instance
    try:
        fetched = api.get_namespaced_custom_object(
            constants.KSERVE_GROUP,
            version,
            namespace,
            KSERVE_PLURAL_LLMINFERENCESERVICE,
            name,
        )
    except client.rest.ApiException as e:
        raise RuntimeError(
            f"❌ Exception when calling CustomObjectsApi->"
            f"get_namespaced_custom_object for LLMInferenceService: {e}"
        ) from e
    return fetched


@log_execution
def wait_for_model_response(
    kserve_client: KServeClient,
    test_case: TestCase,  # noqa: F811
    timeout_seconds: int = 900,
) -> str:
    """Poll the completions endpoint until the test case's assertion passes.

    Every attempt resolves the service URL again, POSTs the prompt to
    ``<url>/v1/completions`` and runs ``test_case.response_assertion`` on the
    reply. Failures to resolve the URL or to perform the request are turned
    into AssertionError so that ``wait_for`` retries them, every 5 seconds.

    Returns:
        The response text, truncated to ``test_case.max_tokens`` characters.
    """

    def probe():
        try:
            base_url = get_llm_service_url(kserve_client, test_case.llm_service)
        except Exception as e:
            raise AssertionError(f"❌ Failed to get service URL: {e}") from e

        endpoint = f"{base_url}/v1/completions"
        body = dict(
            model=test_case.model_name,
            prompt=test_case.prompt,
            max_tokens=test_case.max_tokens,
        )
        print(f"Calling LLM service at {endpoint} with payload {body}")

        try:
            reply = requests.post(
                endpoint,
                headers={"Content-Type": "application/json"},
                json=body,
                timeout=test_case.response_timeout,
            )
        except Exception as e:
            raise AssertionError(f"❌ Failed to call model: {e}") from e

        test_case.response_assertion(reply)
        limit = test_case.max_tokens
        return reply.text[:limit]

    return wait_for(probe, timeout=timeout_seconds, interval=5.0)


def get_llm_service_url(
    kserve_client: KServeClient, llm_isvc: V1alpha1LLMInferenceService
):
    """Return the URL published in the status of an LLMInferenceService.

    The current object is fetched from the cluster. ``status.url`` is
    preferred; otherwise the ``url`` of the first entry in
    ``status.addresses`` is used. Empty values are treated as missing in both
    places.

    Raises:
        ValueError: if the service cannot be fetched, has no status, or its
            status does not carry a URL. Each failure is reported with a
            single message rather than being wrapped a second time.
    """
    service_name = llm_isvc.metadata.name

    # Only the lookup is wrapped: the checks below raise their own ValueError
    # and must not be caught and re-wrapped by this handler.
    try:
        current = get_llmisvc(
            kserve_client,
            service_name,
            llm_isvc.metadata.namespace,
            llm_isvc.api_version.split("/")[1],
        )
    except Exception as e:
        raise ValueError(
            f"❌ Failed to get URL for LLM inference service {service_name}: {e}"
        ) from e

    status = current.get("status")
    if status is None:
        raise ValueError(
            f"❌ No status found in LLM inference service {service_name}: {current}"
        )

    url = status.get("url")
    if url:
        return url

    addresses = status.get("addresses") or []
    if addresses:
        first_url = addresses[0].get("url")
        if first_url:
            return first_url

    raise ValueError(
        f"❌ No URL found in LLM inference service {service_name} status"
    )


@log_execution
def wait_for_llm_isvc_ready(
    kserve_client: KServeClient,
    given: V1alpha1LLMInferenceService,
    timeout_seconds: int = 900,
) -> bool:
    """Wait until the service reports all expected conditions as "True".

    The service is re-fetched every second until its status lists ``Ready``,
    ``WorkloadsReady`` and ``RouterReady`` with status ``"True"``, or until
    ``timeout_seconds`` elapse.

    Returns:
        True once all expected conditions are met.

    Raises:
        AssertionError: if the conditions are still not met at the deadline.
        RuntimeError: if fetching the service fails (raised by ``get_llmisvc``
            and not retried).
    """
    expected_true_conditions = {"Ready", "WorkloadsReady", "RouterReady"}

    def assert_llm_isvc_ready() -> bool:
        out = get_llmisvc(
            kserve_client,
            given.metadata.name,
            given.metadata.namespace,
            given.api_version.split("/")[1],
        )

        # A missing and a null field are handled alike, so that both are
        # retried as "not ready yet" instead of failing with a TypeError.
        status = out.get("status")
        if status is None:
            raise AssertionError("No status found in LLM inference service")

        conditions = status.get("conditions")
        if conditions is None:
            raise AssertionError("No conditions found in status")

        got_true_conditions = {
            condition.get("type")
            for condition in conditions
            if condition.get("status") == "True"
        }

        missing_conditions = expected_true_conditions - got_true_conditions
        if missing_conditions:
            raise AssertionError(
                f"Missing true conditions: {missing_conditions}, expected {expected_true_conditions}, got {conditions}"
            )
        return True

    return wait_for(assert_llm_isvc_ready, timeout=timeout_seconds, interval=1.0)


def wait_for(
    assertion_fn: Callable[[], Any], timeout: float = 5.0, interval: float = 0.1
) -> Any:
    """Call ``assertion_fn`` repeatedly until it stops raising AssertionError.

    Args:
        assertion_fn: Zero-argument callable; its return value is passed back.
        timeout: Maximum number of seconds to keep retrying.
        interval: Number of seconds to sleep between attempts.

    Returns:
        The value returned by the first successful call.

    Raises:
        AssertionError: the last assertion failure, once the deadline passed.
        Exception: anything else raised by ``assertion_fn`` propagates
            immediately without a retry.
    """
    # A monotonic clock keeps the deadline unaffected by system clock changes.
    deadline = time.monotonic() + timeout
    while True:
        try:
            return assertion_fn()
        except AssertionError:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise
            # Never sleep past the deadline; the final attempt happens at it.
            time.sleep(min(interval, remaining))


def _collect_diagnostics(
    kserve_client: KServeClient, llm_isvc: V1alpha1LLMInferenceService
):
    """Print best-effort debugging output for a failed LLMInferenceService.

    Three sections are printed: a YAML dump of the service as currently stored
    in the cluster, the events table for its namespace, and a YAML dump of
    every resource returned by ``kinds_matching_by_labels`` for the service's
    labels. Each section is guarded on its own, so a failure in one (for
    example when the service was never created) is reported inline and does
    not prevent the remaining sections from being printed.
    """
    name = llm_isvc.metadata.name
    ns = llm_isvc.metadata.namespace

    labels = {
        "app.kubernetes.io/part-of": "llminferenceservice",
        "app.kubernetes.io/name": name,
    }

    print(f"🔍 # Diagnostics for {name!r} in {ns!r}")
    print("---")
    print(f"# LLMInferenceService {name}")
    try:
        # Use the object's own API version, like the other helpers in this
        # module, instead of relying on the default.
        svc = get_llmisvc(
            kserve_client, name, ns, llm_isvc.api_version.split("/")[1]
        )
        print(yaml.safe_dump(svc, sort_keys=False))
    except Exception as e:
        print(f"# ❌ failed to dump LLMInferenceService: {e}")

    try:
        print_all_events_table(ns)
    except Exception as e:
        print(f"# ❌ failed to print events: {e}")

    try:
        all_resources = kinds_matching_by_labels(ns, labels)
    except Exception as e:
        print(f"# ❌ failed to list related resources: {e}")
        all_resources = []

    for obj in all_resources:
        print("---")
        try:
            print(yaml.safe_dump(obj.to_dict(), sort_keys=False))
        except Exception as e:
            print(f"# ❌ failed to dump resource: {e}")
